from functools import partial
from math import floor
from typing import List, Callable, Optional

import numpy as np

from DataHandling.FileBasedDatasetBase import FileBasedDataset
from Utils.utils import get_model_id_from_file_name
from Utils.Constants import Diff
from Utils import logger
from DataHandling.LSTMSingleFeatureMapBasedDataset import RegressionSingleFeatureMapDataset
from Encoder.FeatureMapAutoEncoder import FeatureMapAutoEncoder


class CombinedRegressionMapsDataset(FileBasedDataset):
    """
    Dataset that pairs the weights maps and the gradients maps of each model for an LSTM regressor.

    Internally it holds two RegressionSingleFeatureMapDataset instances (one per source) whose items are aligned by
    model id, and joins their data either along the sequence axis or along the feature axis.
    """

    def __init__(self, weights_maps_files: List[str], gradients_maps_files: List[str], results_folders: List[str],
                 result_metric_name: str, pre_load_maps: bool, change_result_values: Optional[Callable[[np.ndarray], np.ndarray]],
                 weights_encoder: FeatureMapAutoEncoder, gradients_encoder: FeatureMapAutoEncoder, use_2_lstms: bool,
                 augment_few_steps_training: Optional[int]):
        """
        Dataset for LSTM using both gradients and weights data as input. It supports both regressor with 1 or 2 LSTM models
        :param weights_maps_files: files of the weights maps; they define the order of the dataset.
        :param gradients_maps_files: files of the gradients maps. They are re-ordered to follow weights_maps_files by
                    matching the model id extracted from each file name, so every weights file needs a gradients file
                    with the same model id (a missing one raises KeyError).
        :param results_folders: forwarded to the base class and to both underlying single-source datasets.
        :param result_metric_name: forwarded to the base class and to both underlying single-source datasets.
        :param pre_load_maps: forwarded to both underlying single-source datasets.
        :param change_result_values: optional function forwarded to both underlying single-source datasets
                    (create_dataset uses it to multiply the values by 10).
        :param weights_encoder: auto encoder handed to the weights dataset.
        :param gradients_encoder: auto encoder handed to the gradients dataset.
        :param use_2_lstms: True if using 2 LSTMs one for each input type, False for 1 LSTM model that uses both sources.
                    True will return results with shape of (sequence size * 2, embedding size),
                    False will return results with shape of (sequence size, embedding size * 2)
        :param augment_few_steps_training: how many steps of data to use for augmenting the dataset.
            if set to None will only return sequences with size of Diff.NUMBER_STEPS_SAVED, otherwise the dataset will
            include data with sequence size of: 1 to augment_few_steps_training and the full size of Diff.NUMBER_STEPS_SAVED.
            e.g. if augment_few_steps_training = 2 and we have 5 data points the dataset will have size of 15. It will
            include datapoint 0 with sizes of: 1, 2 & Diff.NUMBER_STEPS_SAVED,
                    datapoint 1 with sizes of: 1, 2 & Diff.NUMBER_STEPS_SAVED,
                    ...
                    datapoint 4 (last data point) with sizes of: 1, 2 & Diff.NUMBER_STEPS_SAVED.
        """
        super().__init__(files_paths=list(), results_files_locations=results_folders,
                         result_metric_name=result_metric_name)
        self._files_paths = weights_maps_files
        self._pre_load_maps = pre_load_maps
        self._use_2_lstms = use_2_lstms
        # Number of variants per data point: the requested short sizes plus one slot for the full sequence
        # (0 means that augmentation is disabled).
        self._augment_few_steps = 0 if augment_few_steps_training is None else augment_few_steps_training + 1
        self._dataset_size = len(weights_maps_files) * (self._augment_few_steps or 1)

        # Align the gradients files with the weights files so the same index refers to the same model in both datasets.
        gradients_files_map = {get_model_id_from_file_name(file): file for file in gradients_maps_files}
        ordered_gradients_maps = [gradients_files_map[get_model_id_from_file_name(file)] for file in weights_maps_files]
        dataset_func = partial(RegressionSingleFeatureMapDataset, pre_load_maps=pre_load_maps, results_folders=results_folders,
                               change_result_values=change_result_values, result_metric_name=result_metric_name)
        self._weights_dataset = dataset_func(weights_maps_files, auto_encoder=weights_encoder)
        self._gradients_dataset = dataset_func(ordered_gradients_maps, auto_encoder=gradients_encoder)

    @property
    def models_results(self):
        """Results as exposed by the underlying weights dataset."""
        return self._weights_dataset.models_results

    @property
    def models_ids(self):
        """Model ids as exposed by the underlying weights dataset."""
        return self._weights_dataset.models_ids

    @property
    def files(self):
        """Files as exposed by the underlying weights dataset."""
        return self._weights_dataset.files

    def _combine(self, weights_data, gradients_data):
        """Join both sources: along the sequence axis for 2 LSTMs, along the feature axis (axis=1) for 1 LSTM."""
        if self._use_2_lstms:
            return np.concatenate((weights_data, gradients_data))
        return np.concatenate((weights_data, gradients_data), axis=1)

    @staticmethod
    def _cut_and_pad(data, keep_steps: int, total_steps: int):
        """Keep the first keep_steps steps of data and append (total_steps - keep_steps) all-zero steps."""
        kept = np.asarray(data)[:keep_steps, ...]
        # Pad with the dtype of the data: np.zeros defaults to float64, which would silently upcast shortened
        # samples and make their dtype differ from the one of full-length samples.
        padding = np.zeros((total_steps - keep_steps,) + kept.shape[1:], dtype=kept.dtype)
        return np.concatenate([kept, padding])

    def get_data_by_model_id(self, model_id: str):
        """
        Get the full-length (never augmented) combined data of a single model.
        :param model_id: id of the model, as listed in models_ids. An unknown id raises ValueError (list.index).
        :return: tuple of (combined weights & gradients data, result taken from the weights dataset)
        """
        data_idx = self._weights_dataset.models_ids.index(model_id)
        weights_data, y_true = self._weights_dataset[data_idx]
        gradients_data, _ = self._gradients_dataset[data_idx]
        return self._combine(weights_data, gradients_data), y_true

    def __getitem__(self, index):
        """
        Get one sample. With augmentation enabled, consecutive indices are variants of the same data point:
        remainder 0 is the full sequence and remainder k > 0 keeps only the first k steps, zero-padded back to
        Diff.NUMBER_STEPS_SAVED steps.
        :param index: index in the (possibly augmented) dataset
        :return: tuple of (combined weights & gradients data, result taken from the weights dataset)
        """
        if self._augment_few_steps:
            # Integer arithmetic only; the previous floor(index / n) went through float division.
            real_index, cut_size = divmod(index, self._augment_few_steps)
        else:
            real_index, cut_size = index, 0

        if cut_size == 0:
            cut_size = Diff.NUMBER_STEPS_SAVED

        weights_data, weights_result = self._weights_dataset[real_index]
        gradients_data, _ = self._gradients_dataset[real_index]
        if cut_size < Diff.NUMBER_STEPS_SAVED:
            weights_data = self._cut_and_pad(weights_data, cut_size, Diff.NUMBER_STEPS_SAVED)
            gradients_data = self._cut_and_pad(gradients_data, cut_size, Diff.NUMBER_STEPS_SAVED)

        return self._combine(weights_data, gradients_data), weights_result

    @staticmethod
    def create_dataset(weights_maps_files: Optional[List[str]], gradients_maps_files: Optional[List[str]],
                       results_folders: List[str], result_metric_name: str, pre_load_maps: bool, target_mult_10: bool,
                       use_2_lstms: bool, weights_encoder: FeatureMapAutoEncoder, gradients_encoder: FeatureMapAutoEncoder,
                       all_files: Optional[List[str]] = None, augment_few_steps_training: Optional[int] = None):
        """
        Factory for the dataset.
        :param weights_maps_files: files of the weights maps, or None to take them from all_files
        :param gradients_maps_files: files of the gradients maps, or None to take them from all_files
        :param results_folders: forwarded to the dataset constructor
        :param result_metric_name: forwarded to the dataset constructor
        :param pre_load_maps: forwarded to the dataset constructor
        :param target_mult_10: if True the result values are multiplied by 10
        :param use_2_lstms: forwarded to the dataset constructor
        :param weights_encoder: forwarded to the dataset constructor
        :param gradients_encoder: forwarded to the dataset constructor
        :param all_files: used only when both weights_maps_files and gradients_maps_files are None; names containing
                    'weights_' become the weights files and names containing 'gradients_' become the gradients files
        :param augment_few_steps_training: forwarded to the dataset constructor
        :return: the new CombinedRegressionMapsDataset
        :raises ValueError: if the weights or the gradients files list is still missing after considering all_files
        """
        if all_files is not None and weights_maps_files is None and gradients_maps_files is None:
            weights_maps_files = [name for name in all_files if 'weights_' in name]
            gradients_maps_files = [name for name in all_files if 'gradients_' in name]

        # Fail early with a clear message instead of an opaque TypeError from len(None) / iterating None.
        if weights_maps_files is None or gradients_maps_files is None:
            raise ValueError('Both weights_maps_files and gradients_maps_files must be given, '
                             'or both must be None with all_files provided')

        logger().log('CombinedRegressionMapsDataset::create_dataset',
                     f'Will create dataset for: {len(weights_maps_files)} files with aug: {augment_few_steps_training}')
        change_result_func = None if not target_mult_10 else lambda vals: vals * 10
        return CombinedRegressionMapsDataset(weights_maps_files=weights_maps_files, gradients_maps_files=gradients_maps_files,
                                             pre_load_maps=pre_load_maps, results_folders=results_folders,
                                             change_result_values=change_result_func, result_metric_name=result_metric_name,
                                             weights_encoder=weights_encoder, gradients_encoder=gradients_encoder,
                                             use_2_lstms=use_2_lstms, augment_few_steps_training=augment_few_steps_training)
